#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Predicting 3d poses from 2d joints"""
from npu_bridge.npu_init import *
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

import os
import sys
import time
import copy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import cameras
import data_utils
import linear_model
import procrustes
import viz

#from help_modelarts import modelarts_result2obs

## Required parameters
tf.app.flags.DEFINE_string("result", "result", "The result directory where the model checkpoints will be written.")
tf.app.flags.DEFINE_string("obs_dir", "obs://simple-pose-baseline-dataset", "Obs result path")

tf.app.flags.DEFINE_float("learning_rate", 1e-3, "Learning rate")
tf.app.flags.DEFINE_float("dropout", 1, "Dropout keep probability. 1 means no dropout")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size to use during training")
tf.app.flags.DEFINE_integer("epochs", 200, "How many epochs we should train for")
tf.app.flags.DEFINE_boolean("camera_frame", False, "Convert 3d poses to camera coordinates")
tf.app.flags.DEFINE_boolean("max_norm", False, "Apply maxnorm constraint to the weights")
tf.app.flags.DEFINE_boolean("batch_norm", False, "Use batch_normalization")

# Data loading
tf.app.flags.DEFINE_boolean("predict_14", False, "predict 14 joints")
tf.app.flags.DEFINE_string("action", "All", "The action to train on. 'All' means all the actions")

# Architecture
tf.app.flags.DEFINE_integer("linear_size", 1024, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 2, "Number of layers in the model.")
tf.app.flags.DEFINE_boolean("residual", False, "Whether to add a residual connection every 2 layers")

# Evaluation
tf.app.flags.DEFINE_boolean("procrustes", False, "Apply procrustes analysis at test time")
tf.app.flags.DEFINE_boolean("evaluateActionWise", False, "The dataset to use either h36m or heva")

# Directories
tf.app.flags.DEFINE_string("cameras_path", "data/h36m/metadata.xml", "File with h36m metadata, including cameras")
tf.app.flags.DEFINE_string("data_dir", "data/h36m/", "Data directory")
tf.app.flags.DEFINE_string("train_dir", "experiments", "Training directory.")

# Train or load
tf.app.flags.DEFINE_boolean("sample", False, "Set to True for sampling.")
tf.app.flags.DEFINE_boolean("use_cpu", False, "Whether to use the CPU")
tf.app.flags.DEFINE_integer("load", 0, "Try to load a previous checkpoint.")
tf.app.flags.DEFINE_boolean("profiling", True, "Need to profiling.")

# Misc
tf.app.flags.DEFINE_boolean("use_fp16", False, "Train using fp16 instead of fp32.")

FLAGS = tf.app.flags.FLAGS

# train_dir = os.path.join( FLAGS.train_dir,
#   FLAGS.action,
#   'dropout_{0}'.format(FLAGS.dropout),
#   'epochs_{0}'.format(FLAGS.epochs) if FLAGS.epochs > 0 else '',
#   'lr_{0}'.format(FLAGS.learning_rate),
#   'residual' if FLAGS.residual else 'not_residual',
#   'depth_{0}'.format(FLAGS.num_layers),
#   'linear_size{0}'.format(FLAGS.linear_size),
#   'batch_size_{0}'.format(FLAGS.batch_size),
#   'procrustes' if FLAGS.procrustes else 'no_procrustes',
#   'maxnorm' if FLAGS.max_norm else 'no_maxnorm',
#   'batch_normalization' if FLAGS.batch_norm else 'no_batch_normalization',
#   'predict_14' if FLAGS.predict_14 else 'predict_17')
train_dir = FLAGS.result

print(train_dir)
summaries_dir = os.path.join(train_dir, "log")  # Directory for TB summaries

# Create the summaries directory; makedirs with exist_ok avoids the race
# condition discussed in https://github.com/tensorflow/tensorflow/issues/7448
os.makedirs(summaries_dir, exist_ok=True)


def create_model(session, actions, batch_size):
    """
    Create model and initialize it or load its parameters in a session

    Args
      session: tensorflow session
      actions: list of string. Actions to train/test on
      batch_size: integer. Number of examples in each batch
    Returns
      model: The created (or loaded) model
    Raises
      ValueError if asked to load a model, but the checkpoint specified by
      FLAGS.load cannot be found.
    """

    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

    if FLAGS.load <= 0:
        # Create a new model from scratch
        print("Creating model with fresh parameters.")
        session.run(tf.compat.v1.global_variables_initializer())
        return model

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    if ckpt and ckpt.model_checkpoint_path:
        # FLAGS.load > 0 here (the fresh-parameters case returned above),
        # so check that the requested checkpoint exists
        if os.path.isfile(os.path.join(train_dir, "checkpoint-{0}.index".format(FLAGS.load))):
            ckpt_path = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
        else:
            raise ValueError("Asked to load checkpoint {0}, but it does not seem to exist".format(FLAGS.load))

        print("Loading model {0}".format(os.path.basename(ckpt_path)))
        model.saver.restore(session, ckpt_path)
        return model
    else:
        print("Could not find checkpoint. Aborting.")
        raise ValueError("No checkpoint found in {0}".format(train_dir))


def train():
    """Train a linear model for 3d pose estimation"""

    actions = data_utils.define_actions(FLAGS.action)

    number_of_actions = len(actions)

    # Load camera parameters
    SUBJECT_IDS = [1, 5, 6, 7, 8, 9, 11]
    this_file = os.path.dirname(os.path.realpath(__file__))
    rcams = cameras.load_cameras(os.path.join(this_file, "..", FLAGS.cameras_path), SUBJECT_IDS)

    # Load 3d data and load (or create) 2d projections
    train_set_3d, test_set_3d, data_mean_3d, data_std_3d, dim_to_ignore_3d, dim_to_use_3d, train_root_positions, test_root_positions = data_utils.read_3d_data(
        actions, FLAGS.data_dir, FLAGS.camera_frame, rcams, FLAGS.predict_14)

    # Read groundtruth 2D projections
    train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.create_2d_data(
        actions, FLAGS.data_dir, rcams)
    print("done reading and normalizing data.")

    global_config = tf.ConfigProto()
    custom_op = global_config.graph_options.rewrite_options.custom_optimizers.add()
    # mix_precision
    custom_op.name = "NpuOptimizer"
    custom_op.parameter_map["use_off_line"].b = True
    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
    # Profiling
    if FLAGS.profiling:
        custom_op.parameter_map["profiling_mode"].b = True
        custom_op.parameter_map["profiling_options"].s = tf.compat.as_bytes('{"output":"/cache/profiling",'
                                                                            '"task_trace":"on",'
                                                                            '"aicpu":"on"}')
        custom_op.parameter_map["fp_point"].s = tf.compat.as_bytes("ExponentialDecay/truediv")
        custom_op.parameter_map["bp_point"].s = tf.compat.as_bytes("gradients/AddN_23")

    global_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
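    # Turning remapping off keeps TF's remapper from interfering with the NPU
    # optimizer's own fusion passes (a common recommendation in Ascend
    # migration guides).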

    # The GPU device_count knob from the original code is unused here: the
    # session is configured entirely through the NpuOptimizer options above.
    with tf.compat.v1.Session(
            config=global_config
    ) as sess:

        # === Create the model ===
        print("Creating %d bi-layers of %d units." % (FLAGS.num_layers, FLAGS.linear_size))
        model = create_model(sess, actions, FLAGS.batch_size)
        model.train_writer.add_graph(sess.graph)
        print("Model created")

        # === This is the training loop ===
        current_step = 0 if FLAGS.load <= 0 else FLAGS.load + 1
        current_epoch = 0
        step_time, loss = 0, 0
        log_every_n_batches = 100

        training_start_time = time.time()

        for _ in range(FLAGS.epochs):
            current_epoch = current_epoch + 1

            # === Load training batches for one epoch ===
            encoder_inputs, decoder_outputs = model.get_all_batches(train_set_2d, train_set_3d, FLAGS.camera_frame,
                                                                    training=True)
            nbatches = len(encoder_inputs)
            print("There are {0} train batches".format(nbatches))  #24371
            start_time, loss = time.time(), 0.

            # === Loop through all the training batches ===
            for i in range(nbatches):

                enc_in, dec_out = encoder_inputs[i], decoder_outputs[i]
                step_loss, loss_summary, lr_summary, _ = model.step(sess, enc_in, dec_out, FLAGS.dropout,
                                                                    isTraining=True)

                if (i + 1) % log_every_n_batches == 0:
                    # Log and print progress every log_every_n_batches batches
                    print("Working on epoch {0},  batch {1} / {2}, loss {3} ... ".format(current_epoch, i + 1, nbatches,
                                                                                         step_loss), end="")
                    model.train_writer.add_summary(loss_summary, current_step)
                    model.train_writer.add_summary(lr_summary, current_step)
                    step_time = (time.time() - start_time)
                    start_time = time.time()
                    print("done in {0:.2f} ms".format(1000 * step_time / log_every_n_batches))

                loss += step_loss
                current_step += 1
                # === end looping through training batches ===

            loss = loss / nbatches
            print("=============================\n"
                  "Global step:         %d\n"
                  "Learning rate:       %.2e\n"
                  "Train loss avg:      %.4f\n"
                  "=============================" % (model.global_step.eval(),
                                                     model.learning_rate.eval(), loss))
            # === End training for an epoch ===

            # === Testing after this epoch ===
            isTraining = False

            if FLAGS.evaluateActionWise:

                print("{0:=^12} {1:=^6}".format("Action", "mm"))  # line of 30 equal signs

                cum_err = 0
                for action in actions:
                    print("{0:<12} ".format(action), end="")
                    # Get 2d and 3d testing data for this action
                    action_test_set_2d = get_action_subset(test_set_2d, action)
                    action_test_set_3d = get_action_subset(test_set_3d, action)
                    encoder_inputs, decoder_outputs = model.get_all_batches(action_test_set_2d, action_test_set_3d,
                                                                            FLAGS.camera_frame, training=False)

                    act_err, _, step_time, loss = evaluate_batches(sess, model,
                                                                    data_mean_3d, data_std_3d, dim_to_use_3d,
                                                                    dim_to_ignore_3d,
                                                                    data_mean_2d, data_std_2d, dim_to_use_2d,
                                                                    dim_to_ignore_2d,
                                                                    current_step, encoder_inputs, decoder_outputs)
                    cum_err = cum_err + act_err

                    print("{0:>6.2f}".format(act_err))

                summaries = sess.run(model.err_mm_summary, {model.err_mm: float(cum_err / float(len(actions)))})
                model.test_writer.add_summary(summaries, current_step)
                print("{0:<12} {1:>6.2f}".format("Average", cum_err / float(len(actions))))
                print("{0:=^19}".format(''))

            else:

                n_joints = 17 if not (FLAGS.predict_14) else 14
                encoder_inputs, decoder_outputs = model.get_all_batches(test_set_2d, test_set_3d, FLAGS.camera_frame,
                                                                        training=False)

                total_err, joint_err, step_time, loss = evaluate_batches(sess, model,
                                                                            data_mean_3d, data_std_3d, dim_to_use_3d,
                                                                            dim_to_ignore_3d,
                                                                            data_mean_2d, data_std_2d, dim_to_use_2d,
                                                                            dim_to_ignore_2d,
                                                                            current_step, encoder_inputs, decoder_outputs,
                                                                            current_epoch)

                print("=============================\n"
                        "Step-time (ms):      %.4f\n"
                        "Val loss avg:        %.4f\n"
                        "Val error avg (mm):  %.2f\n"
                        "=============================" % (1000 * step_time, loss, total_err))

                for i in range(n_joints):
                    # right-aligned, width 5, 2 decimal places
                    print("Error in joint {0:02d} (mm): {1:>5.2f}".format(i + 1, joint_err[i]))
                print("=============================")

                # Log the error to tensorboard
                summaries = sess.run(model.err_mm_summary, {model.err_mm: total_err})
                model.test_writer.add_summary(summaries, current_step)

            # Reset global time and loss
            step_time, loss = 0, 0

            sys.stdout.flush()

            # Save the model
            print("Saving the model... ", end="")
            start_time = time.time()
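            # Writes train_dir/checkpoint-<current_step>.*, the naming that
            # create_model() expects when restoring with --load.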
            model.saver.save(sess, os.path.join(train_dir, 'checkpoint'), global_step=current_step)
            print("done in {0:.2f} ms".format(1000 * (time.time() - start_time)))

            training_end_time = time.time()
            print("Elapsed training time so far: {0:.2f} ms".format(1000 * (training_end_time - training_start_time)))

            #modelarts_result2obs(FLAGS)


def get_action_subset(poses_set, action):
    """
    Given a preloaded dictionary of poses, load the subset of a particular action

    Args
      poses_set: dictionary with keys k=(subject, action, seqname),
        values v=(nxd matrix of poses)
      action: string. The action that we want to filter out
    Returns
      poses_subset: dictionary with same structure as poses_set, but only with the
        specified action.
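    Example (hypothetical keys):
      get_action_subset(poses, "Walking") keeps (1, "Walking", "Walking 1.h5")
      and drops (1, "Eating", "Eating.h5").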
    """
    return {k: v for k, v in poses_set.items() if k[1] == action}


def evaluate_batches(sess, model,
                     data_mean_3d, data_std_3d, dim_to_use_3d, dim_to_ignore_3d,
                     data_mean_2d, data_std_2d, dim_to_use_2d, dim_to_ignore_2d,
                     current_step, encoder_inputs, decoder_outputs, current_epoch=0):
    """
    Generic method that evaluates performance of a list of batches.
    May be used to evaluate all actions or a single action.

    Args
      sess: tensorflow session
      model: tensorflow model to run evaluation with
      data_mean_3d: the mean of the training data in 3d
      data_std_3d: the standard deviation of the training data in 3d
      dim_to_use_3d: out of all the 96 dimensions that represent a 3d body in h36m, compute results for this subset
      dim_to_ignore_3d: complement of the above
      data_mean_2d: mean of the training data in 2d
      data_std_2d: standard deviation of the training data in 2d
      dim_to_use_2d: out of the 64 dimensions that represent a body in 2d in h36m, use this subset
      dim_to_ignore_2d: complement of the above
      current_step: training iteration step
      encoder_inputs: input for the network
      decoder_outputs: expected output for the network
      current_epoch: current training epoch
    Returns
      total_err: average mm error over all joints
      joint_err: average mm error per joint
      step_time: time it took to evaluate one batch
      loss: validation loss of the network
    """

    n_joints = 17 if not (FLAGS.predict_14) else 14
    nbatches = len(encoder_inputs)

    # Loop through test examples
    all_dists, start_time, loss = [], time.time(), 0.
    log_every_n_batches = 100
    for i in range(nbatches):

        if current_epoch > 0 and (i + 1) % log_every_n_batches == 0:
            print("Working on test epoch {0}, batch {1} / {2}".format(current_epoch, i + 1, nbatches))

        enc_in, dec_out = encoder_inputs[i], decoder_outputs[i]
        dp = 1.0  # dropout keep probability is always 1 at test time
        step_loss, loss_summary, poses3d = model.step(sess, enc_in, dec_out, dp, isTraining=False)
        loss += step_loss

        # denormalize
        enc_in = data_utils.unNormalizeData(enc_in, data_mean_2d, data_std_2d, dim_to_ignore_2d)
        dec_out = data_utils.unNormalizeData(dec_out, data_mean_3d, data_std_3d, dim_to_ignore_3d)
        poses3d = data_utils.unNormalizeData(poses3d, data_mean_3d, data_std_3d, dim_to_ignore_3d)

        # Keep only the relevant dimensions
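        # For the 17-joint case the root joint (first 3 dims) is prepended,
        # since data_utils marks it as ignored after zero-centering the poses.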
        dtu3d = np.hstack((np.arange(3), dim_to_use_3d)) if not (FLAGS.predict_14) else dim_to_use_3d

        dec_out = dec_out[:, dtu3d]
        poses3d = poses3d[:, dtu3d]

        assert dec_out.shape[0] == FLAGS.batch_size
        assert poses3d.shape[0] == FLAGS.batch_size

        if FLAGS.procrustes:
            # Apply per-frame procrustes alignment if asked to do so
            for j in range(FLAGS.batch_size):
                gt = np.reshape(dec_out[j, :], [-1, 3])
                out = np.reshape(poses3d[j, :], [-1, 3])
                _, Z, T, b, c = procrustes.compute_similarity_transform(gt, out, compute_optimal_scale=True)
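                # Align the prediction to the ground truth with the optimal
                # rotation T, scale b and translation c, so only the pose
                # (not its rigid placement) is scored.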
                out = (b * out.dot(T)) + c

                poses3d[j, :] = np.reshape(out, [-1, 17 * 3]) if not (FLAGS.predict_14) else np.reshape(out,
                                                                                                        [-1, 14 * 3])

        # Compute Euclidean distance error per joint
        sqerr = (poses3d - dec_out) ** 2  # Squared error between prediction and expected output
        dists = np.zeros((sqerr.shape[0], n_joints))  # Array with L2 error per joint in mm
        dist_idx = 0
        for k in np.arange(0, n_joints * 3, 3):
            # Sum squared error over X, Y and Z, then take the square root to get the L2 distance
            dists[:, dist_idx] = np.sqrt(np.sum(sqerr[:, k:k + 3], axis=1))
            dist_idx = dist_idx + 1
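        # Equivalent vectorized form (a sketch giving the same result):
        #   dists = np.sqrt(sqerr.reshape(-1, n_joints, 3).sum(axis=2))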

        all_dists.append(dists)
        assert sqerr.shape[0] == FLAGS.batch_size

    step_time = (time.time() - start_time) / nbatches
    loss = loss / nbatches

    all_dists = np.vstack(all_dists)

    # Error per joint and total for all passed batches
    joint_err = np.mean(all_dists, axis=0)
    total_err = np.mean(all_dists)

    return total_err, joint_err, step_time, loss


def sample():
    """Get samples from a model and visualize them"""

    actions = data_utils.define_actions(FLAGS.action)

    # Load camera parameters
    SUBJECT_IDS = [1, 5, 6, 7, 8, 9, 11]
    this_file = os.path.dirname(os.path.realpath(__file__))
    rcams = cameras.load_cameras(os.path.join(this_file, "..", FLAGS.cameras_path), SUBJECT_IDS)

    # Load 3d data and load (or create) 2d projections
    train_set_3d, test_set_3d, data_mean_3d, data_std_3d, dim_to_ignore_3d, dim_to_use_3d, train_root_positions, test_root_positions = data_utils.read_3d_data(
        actions, FLAGS.data_dir, FLAGS.camera_frame, rcams, FLAGS.predict_14)

    train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.create_2d_data(
        actions, FLAGS.data_dir, rcams)
    print("done reading and normalizing data.")

    device_count = {"GPU": 0} if FLAGS.use_cpu else {"GPU": 1}
    with tf.compat.v1.Session(config=npu_config_proto(config_proto=tf.ConfigProto(device_count=device_count))) as sess:
        # === Create the model ===
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.linear_size))
        batch_size = 128
        model = create_model(sess, actions, batch_size)
        print("Model loaded")

        for key2d in test_set_2d.keys():

            (subj, b, fname) = key2d
            print("Subject: {}, action: {}, fname: {}".format(subj, b, fname))

            # keys should be the same if 3d is in camera coordinates
            key3d = key2d if FLAGS.camera_frame else (subj, b, '{0}.h5'.format(fname.split('.')[0]))
            # key3d = (subj, b, fname[:-3]) if (fname.endswith('-sh')) and FLAGS.camera_frame else key3d

            enc_in = test_set_2d[key2d]
            n2d, _ = enc_in.shape
            dec_out = test_set_3d[key3d]
            n3d, _ = dec_out.shape
            assert n2d == n3d

            # Split into about-same-size batches
            enc_in = np.array_split(enc_in, n2d // batch_size)
            dec_out = np.array_split(dec_out, n3d // batch_size)
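            # np.array_split allows uneven splits, so each chunk has roughly
            # batch_size rows (any remainder is spread across the chunks).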
            all_poses_3d = []

            for bidx in range(len(enc_in)):
                # Dropout probability 0 (keep probability 1) for sampling
                dp = 1.0
                _, _, poses3d = model.step(sess, enc_in[bidx], dec_out[bidx], dp, isTraining=False)

                # denormalize
                enc_in[bidx] = data_utils.unNormalizeData(enc_in[bidx], data_mean_2d, data_std_2d, dim_to_ignore_2d)
                dec_out[bidx] = data_utils.unNormalizeData(dec_out[bidx], data_mean_3d, data_std_3d, dim_to_ignore_3d)
                poses3d = data_utils.unNormalizeData(poses3d, data_mean_3d, data_std_3d, dim_to_ignore_3d)
                all_poses_3d.append(poses3d)

            # Put all the poses together
            enc_in, dec_out, poses3d = map(np.vstack, [enc_in, dec_out, all_poses_3d])

            # Convert back to world coordinates
            if FLAGS.camera_frame:
                N_CAMERAS = 4
                N_JOINTS_H36M = 32

                # Add global position back
                dec_out = dec_out + np.tile(test_root_positions[key3d], [1, N_JOINTS_H36M])

                # Load the appropriate camera
                subj, _, sname = key3d

                cname = sname.split('.')[1]  # <-- camera name
                scams = {(subj, c + 1): rcams[(subj, c + 1)] for c in range(N_CAMERAS)}  # cams of this subject
                scam_idx = [scams[(subj, c + 1)][-1] for c in range(N_CAMERAS)].index(cname)  # index of camera used
                the_cam = scams[(subj, scam_idx + 1)]  # <-- the camera used
                R, T, f, c, k, p, name = the_cam
                assert name == cname

                def cam2world_centered(data_3d_camframe):
                    data_3d_worldframe = cameras.camera_to_world_frame(data_3d_camframe.reshape((-1, 3)), R, T)
                    data_3d_worldframe = data_3d_worldframe.reshape((-1, N_JOINTS_H36M * 3))
                    # subtract root translation
                    return data_3d_worldframe - np.tile(data_3d_worldframe[:, :3], (1, N_JOINTS_H36M))

                # Apply inverse rotation and translation
                dec_out = cam2world_centered(dec_out)
                poses3d = cam2world_centered(poses3d)

    # Grab a random batch to visualize
    enc_in, dec_out, poses3d = map(np.vstack, [enc_in, dec_out, poses3d])
    idx = np.random.permutation(enc_in.shape[0])
    enc_in, dec_out, poses3d = enc_in[idx, :], dec_out[idx, :], poses3d[idx, :]

    # Visualize random samples
    import matplotlib.gridspec as gridspec

    # 1080p = 1920 x 1080
    fig = plt.figure(figsize=(19.2, 10.8))

    gs1 = gridspec.GridSpec(5, 9)  # 5 rows, 9 columns
    gs1.update(wspace=-0.00, hspace=0.05)  # set the spacing between axes.
    plt.axis('off')

    subplot_idx, exidx = 1, 1
    nsamples = 15
    for i in np.arange(nsamples):
        # Plot 2d pose
        ax1 = plt.subplot(gs1[subplot_idx - 1])
        p2d = enc_in[exidx, :]
        viz.show2Dpose(p2d, ax1)
        ax1.invert_yaxis()

        # Plot 3d gt
        ax2 = plt.subplot(gs1[subplot_idx], projection='3d')
        p3d = dec_out[exidx, :]
        viz.show3Dpose(p3d, ax2)

        # Plot 3d predictions
        ax3 = plt.subplot(gs1[subplot_idx + 1], projection='3d')
        p3d = poses3d[exidx, :]
        viz.show3Dpose(p3d, ax3, lcolor="#9b59b6", rcolor="#2ecc71")

        exidx = exidx + 1
        subplot_idx = subplot_idx + 3

    plt.show()


def main(_):
    if FLAGS.sample:
        sample()
    else:
        train()


if __name__ == "__main__":
    tf.compat.v1.app.run()
